import rpy2.robjects as robjects
import rpy2.robjects.numpy2ri
import numpy as np
import tensorly as tl
from tensorly.tenalg import multi_mode_dot as mmd

# Global configuration: tensorly operates on NumPy arrays, rpy2 converts
# NumPy arrays to/from R automatically, and the R source file is loaded
# (presumably it defines the `speck` function called below — verify).
tl.set_backend("numpy")
rpy2.robjects.numpy2ri.activate()
robjects.r["source"]("Codes_Spectral_Matrix.r")

# Edge tensor and node labels. Later code sums axes (0, 1) per layer and
# selects layers on axis 2, so axes 0/1 appear to index nodes and axis 2
# layers — confirm against how wat_edge.npy was produced.
Tensor = np.load("wat_edge.npy")
Node_set = np.load("wat_node.npy")
# Hard-coded label override for node 58 (presumably fixing a mis-encoded
# country name — verify). If Node_set has a fixed-width string dtype, the
# assigned text is truncated to that width.
Node_set[58] = "Cote d'Ivoire"
# Binarize in place: every positive entry becomes 1; other entries are kept.
np.putmask(Tensor, Tensor > 0, 1)


M = 32               # number of layers to keep (largest total entry count)
MIN_AVG_DEGREE = 9   # keep nodes whose degree averaged over the M layers exceeds this

# Per-layer total of the (binarized) tensor entries.
layer_deg = Tensor.sum(axis=(0, 1))
if M > layer_deg.size:
    # Without this check fewer than M layers would be kept silently, and the
    # experiment would later fail on an (n, n, M) shape mismatch.
    raise ValueError(f"M={M} exceeds the number of layers ({layer_deg.size})")
# Indices of the M layers with the largest totals, in decreasing order.
index = np.argsort(layer_deg)[::-1][:M]
NT = Tensor[:, :, index]

# Node degree summed over the selected layers; keep the denser nodes.
deg = NT.sum(axis=(1, 2))
ind = np.flatnonzero(deg / M > MIN_AVG_DEGREE)
if ind.size == 0:
    raise ValueError(
        f"no node has average degree > {MIN_AVG_DEGREE} over the top {M} layers")
# NOTE: Node_set is deliberately left unfiltered, so its positions no longer
# line up with NT's node axis; index it with `ind` if labels are needed.
NT = NT[np.ix_(ind, ind)]
n = NT.shape[0]


K = 6                       # number of communities requested from SPECK
fraction_of_training = 0.8  # probability an (upper-triangular) entry is used for training
Repetition = 50             # number of independent random train/test splits


def _symmetric_training_mask(n, M, p, upper):
    """Draw a symmetric Bernoulli(p) training mask of shape (n, n, M).

    Entries on and above the diagonal (``upper``) are drawn independently;
    the strict lower triangle mirrors them so every layer is symmetric.
    """
    draws = np.random.binomial(1, p, (n, n, M))
    return np.where(upper[:, :, None], draws, draws.transpose(1, 0, 2))


def _block_means(Z, values, weights):
    """Return the (K, K, M) array of block means of ``values`` under ``weights``.

    Entry [k, l, m] is sum(values[i, j, m]) / sum(weights[i, j, m]) over
    nodes i in community k and j in community l (``Z`` is the one-hot
    (n, K) membership matrix). Blocks with zero total weight get 0.
    """
    totals = np.einsum('ik,ijm,jl->klm', Z, values, Z)
    counts = np.einsum('ik,ijm,jl->klm', Z, weights, Z)
    return np.divide(totals, counts, out=np.zeros(totals.shape), where=counts > 0)


# Selects entries on and above the diagonal; held-out entries are counted
# only there, so each symmetric pair is scored once.
upper = np.triu(np.ones((n, n), dtype=bool))

link_prediciton_error = []  # (name kept: the summary print reads it)
for _ in range(Repetition):
    B = _symmetric_training_mask(n, M, fraction_of_training, upper)
    A_speck = NT * B  # training network: held-out entries are zeroed out
    X = [A_speck[:, :, m] for m in range(M)]
    # SPECK's labels are shifted down by 1 to index np.eye(K);
    # coerce to a flat integer array so they can be used as indices.
    psi_hat = np.asarray(robjects.r.speck(X, n, K), dtype=int).ravel() - 1
    Z = np.eye(K)[psi_hat]
    # Estimate block probabilities from TRAINING entries only: dividing by
    # the full block size would treat held-out entries as observed zeros and
    # bias the estimates down by roughly fraction_of_training.
    Core = _block_means(Z, A_speck, B)
    P = mmd(Core, [Z] * 2, modes=[0, 1])
    A_hat = np.random.binomial(1, P)
    # Score predictions against the TRUE network on held-out entries.
    # A_speck is identically 0 there, so it cannot serve as ground truth.
    held_out = upper[:, :, None] & (B == 0)
    link_prediciton_error.append(np.abs(A_hat - NT)[held_out].mean())

# Summarize the held-out link prediction error across replications. The
# standard error of the mean uses the sample standard deviation (ddof=1),
# the conventional estimate; with a single replication it is undefined (nan).
errors = np.asarray(link_prediciton_error)
print("The averaged link prediction error by SPECK over", Repetition, "independent replications is",
      errors.mean(), "with standard error",
      errors.std(ddof=1) / np.sqrt(errors.size))
